source('codes/00_import_libraries.R')


#------------------------------------------------------------------------------#
# data -------------------------------------------------------------------------
#------------------------------------------------------------------------------#

# Main data set: the 'base' element of the list stored in combined_data.rds,
# with the recursive PLS series renamed to the short name used below.
dt = readRDS('input/combined_data.rds')[['base']] %>% 
  rename(PLS = PLS_All_recursive)


# Recession dummy reshaped into rectangle coordinates for geom_rect():
# each row spans from its own date (xmin) to the next row's date (xmax),
# with a fixed vertical extent [ymin, ymax] on the plot's y scale.
# NOTE(review): xmax is NA on the last row because lead() has no successor,
# so that row cannot be drawn as a rectangle.
recess = read_excel('input/us recession dummy.xlsx') %>% 
  mutate(
    date = as.Date(as.yearmon(ym)),
    xmin = date,
    xmax = lead(date),
    ymin = -0.4,
    ymax = 1.0
  ) %>% 
  # Keep January 1871 onwards. `date` is built from a yearmon, so comparing it
  # with an ISO date is equivalent to the yearmon comparison, but it does not
  # rely on parsing the month name 'Jan', which depends on the session locale.
  filter(date >= as.Date('1871-01-01'))



#------------------------------------------------------------------------------#
# function ---------------------------------------------------------------------
#------------------------------------------------------------------------------#

#' Cumulative in-sample R-squared path of a univariate regression
#'
#' Fits `yvar ~ xvar` by OLS (with intercept) on the rows of `dt` where `ym`,
#' `xvar` and `yvar` are all non-missing, and returns for every retained row t
#'
#'   ( sum_{s<=t} (y_s - mean(y))^2 - sum_{s<=t} e_s^2 ) / sum_s (y_s - mean(y))^2
#'
#' where `e` are the full-sample residuals and the sums follow the row order
#' of `dt`. The last value therefore equals the ordinary in-sample R-squared.
#'
#' @param dt   Data frame with a `ym` column plus the columns named by
#'             `xvar` and `yvar`.
#' @param yvar Name (single string) of the dependent variable column.
#' @param xvar Name (single string) of the predictor column.
#' @return A data.frame with columns `ym` and `<xvar>_r2`, one row per
#'         complete observation.
fn_cum_is_r2 = function(dt, yvar, xvar) {
  
  stopifnot(
    is.character(xvar), length(xvar) == 1L,
    is.character(yvar), length(yvar) == 1L
  )
  
  # Select columns strictly by the names held in xvar / yvar. Passing the bare
  # symbols to select() is deprecated and would pick up columns of `dt` that
  # happen to be called 'xvar' or 'yvar'.
  cols = unique(c('ym', xvar, yvar))
  missing_cols = setdiff(cols, names(dt))
  if (length(missing_cols) > 0) {
    stop(
      'column(s) not found in `dt`: ', paste(missing_cols, collapse = ', '),
      call. = FALSE
    )
  }
  
  df = as.data.frame(dt)[, cols, drop = FALSE]
  df = df[complete.cases(df), , drop = FALSE]
  x = df[[xvar]]
  y = df[[yvar]]
  
  # Squared deviations from the full-sample mean: the benchmark "model".
  dev2 = (y - mean(y))^2
  e_hat = unname(resid(lm(y ~ x)))
  
  cum_r2 = (cumsum(dev2) - cumsum(e_hat^2)) / sum(dev2)
  
  setNames(data.frame(df$ym, cum_r2), c('ym', paste0(xvar, '_r2')))
}



#------------------------------------------------------------------------------#
# apply ------------------------------------------------------------------------
#------------------------------------------------------------------------------#

# Predictors whose cumulative R2 paths are drawn, and the common regressand.
x_list = c('War', 'PLS')
yvar = 'mkt1'


# Build the figure inside local() so the intermediate tables stay out of the
# workspace; only `gplot` is assigned.
gplot = local({
  
  # One '<predictor>_r2' column per predictor, joined on `ym`.
  r2_wide = reduce(
    lapply(x_list, function(predictor) {
      fn_cum_is_r2(dt = dt, yvar = yvar, xvar = predictor)
    }),
    left_join, by = 'ym'
  )
  
  # Long format: one row per (ym, predictor).
  r2_long = pivot_longer(r2_wide, cols = !ym, names_to = 'var', values_to = 'val')
  
  r2_long = mutate(
    r2_long,
    # express the R2 in percent
    val  = 100 * val,
    # legend/label text is the predictor name with any digits stripped
    var  = factor(
      var,
      levels = paste0(x_list, '_r2'),
      labels = str_replace_all(x_list, '\\d+', '')
    ),
    ym   = as.yearmon(ym),
    date = as.Date(ym),
    # only the final observation of each series carries a text label
    lab  = if_else(ym == max(ym), as.character(var), '')
  )
  
  decade_breaks = seq(as.Date('1870-01-01'), as.Date('2020-01-01'), by = '10 years')
  
  ggplot(r2_long, aes(x = date, y = val, linetype = var)) +
    
    # recession shading, drawn first so the lines sit on top of it
    geom_rect(
      data = recess,
      mapping = aes(
        xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax,
        fill = as.factor(recess)
      ),
      inherit.aes = FALSE, alpha = 1
    ) +
    scale_fill_manual(values = c('white', 'grey')) +
    
    geom_line(size = 1) +
    
    labs(title = '', y = '%', x = '') +
    scale_y_continuous(breaks = seq(-0.2, 1, 0.2)) +
    scale_x_date(breaks = decade_breaks, date_labels = '%Y', expand = c(0.05, 0)) +
    
    # series names placed at the end of each line
    geom_text_repel(
      aes(label = lab),
      max.overlaps = 200, size = 7, fontface = 'plain', nudge_x = 10
    ) +
    
    theme_base() +
    theme(
      legend.position = 'none',
      text            = element_text(size = 25),
      plot.title      = element_text(hjust = 0.5),
      plot.background = element_rect(colour = NA)
    )
})


ggsave(
  filename = 'output/figures/Figure_5A.tiff', plot = gplot,
  device = 'tiff', width = 12, height = 6, dpi = 300
)


